fix(engine): per-engine threads to eliminate cross-engine stream contamination#1304
Conversation
0760768 to
0835164
Compare
|
Thanks for putting this together, the analysis is spot on. The per-engine stream isolation is clearly the right fix for the cross-engine contamination, and the writeup with the concurrent throughput numbers and the regex-based tests is really thoughtful work. My only hesitation is the scope. This touches the scheduler core and the MTP patch with 37 stream references rewritten, and I want to give it more soak time than the 0.3.9 release window allows. So I am going to hold this out of the 0.3.9 stable release and pull it in starting from the next dev build instead. I know you have been waiting on this, sorry for the delay and thanks for your patience. The change itself looks solid, I just want to land it where it can get proper testing before it hits a stable release. |
…amination Replace the shared _global_mlx_executor with per-EngineCore ThreadPoolExecutor + mx.Stream, and fix the MTP patch reading the module-level generation_stream instead of the per-engine stream.
0835164 to
587d77c
Compare
|
Rebased locally to absorb two small fixups: dropped the 5th The MTP-via-enclosing-context claim checks out — Thanks again for the careful writeup and the perf numbers. |
BatchGenerator.__init__ already calls mx.set_wired_limit() on each instance, and concurrent calls with the same value are race-free (verified empirically). The guard never prevented the race it claimed to fix. Follow-up to #1304.
mlx-vlm's load() only materializes model.language_model.parameters(), leaving frozen buffers (RoPE freqs) and sibling sub-trees (vision_tower, audio_tower) as lazy arrays bound to the loader thread's default stream. Pre-#1304 this was invisible because loader and forward shared one global thread; the per-engine executor split exposed it as "no Stream(gpu, X) in current thread" when mx.eval touches a sibling buffer during prefill. Fix: materialize the full model tree on the loader thread right after load. Verified against gemma-4-E2B-it and gemma-4-31b-it-4bit. Also fixes test_safe_sync_passes_generation_stream to match the _default_generation_stream alias introduced by #1304. Reported by @zviratko in #1304.
) #1304 (``fix(engine): per-engine threads to eliminate cross-engine stream contamination``) refactored the patched ``BatchGenerator`` to inherit its execution stream from the enclosing engine context, and removed the module-level ``_get_generation_stream`` helper as part of that. ``TestBatchGeneratorDispatch._make_reconcile_batch`` still tried to monkeypatch that name and failed at collection of every test that depends on the fixture with ``AttributeError: module ... has no attribute '_get_generation_stream'`` — taking down 4 reconcile-path tests on every CI run. The override is no longer needed: the surrounding fixture replaces ``_rebuild_singleton_cache`` and ``_call_backbone`` with fakes that do all of their work via ``np.array`` / ``mx.array`` directly, so neither MLX dispatch nor stream selection is reached. Tests (tests/test_mlx_lm_mtp_patch.py::TestBatchGeneratorDispatch): - test_reconcile_uses_queue_front_as_next_token - test_reconcile_empty_queue_samples_from_logits - test_reconcile_returns_false_on_empty_tokens - test_reconcile_fallback_on_rebuild_failure 11/11 TestBatchGeneratorDispatch tests pass.
Catch-up sync. Highlights: - boundary-store cleanup race fix (jundot#1423) — eliminates the test_cleanup_all_drains_queue flake we have been carrying. - per-engine MLX threads / streams (jundot#1304) — multiple models stepping scheduler.step() concurrently no longer cross-contaminate streams. - VLM lazy state materialized on loader thread; skip MTPModule attach when checkpoint lacks mtp.* weights. - Dead TieredCacheManager removed; profiles three-scope template refactor (jundot#1399). 4567 pass / 3 known env-override fails / 36 skip. Zero regression. User delegated review. --- Catch-up 同步. 主要内容: - boundary-store 清理 race 修复 (jundot#1423) - 消掉一直拖着的 test_cleanup_all_drains_queue flake. - 每引擎 MLX 线程 / 流 (jundot#1304) - 多模型并发 scheduler.step() 不再 cross-contaminate. - VLM 在 loader 线程实例化 lazy 状态; checkpoint 无 mtp.* 权重时 跳过 MTPModule attach. - 删 dead TieredCacheManager; profiles three-scope template 重构 (jundot#1399). 4567 pass / 3 known env-override fails / 36 skip. 零回归. 用户委托 review.
Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
…#1485) Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the #1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
…amination (jundot#1304) Replace the shared _global_mlx_executor with per-EngineCore ThreadPoolExecutor + mx.Stream, and fix the MTP patch reading the module-level generation_stream instead of the per-engine stream.
BatchGenerator.__init__ already calls mx.set_wired_limit() on each instance, and concurrent calls with the same value are race-free (verified empirically). The guard never prevented the race it claimed to fix. Follow-up to jundot#1304.
mlx-vlm's load() only materializes model.language_model.parameters(), leaving frozen buffers (RoPE freqs) and sibling sub-trees (vision_tower, audio_tower) as lazy arrays bound to the loader thread's default stream. Pre-jundot#1304 this was invisible because loader and forward shared one global thread; the per-engine executor split exposed it as "no Stream(gpu, X) in current thread" when mx.eval touches a sibling buffer during prefill. Fix: materialize the full model tree on the loader thread right after load. Verified against gemma-4-E2B-it and gemma-4-31b-it-4bit. Also fixes test_safe_sync_passes_generation_stream to match the _default_generation_stream alias introduced by jundot#1304. Reported by @zviratko in jundot#1304.
…ndot#1445) jundot#1304 (``fix(engine): per-engine threads to eliminate cross-engine stream contamination``) refactored the patched ``BatchGenerator`` to inherit its execution stream from the enclosing engine context, and removed the module-level ``_get_generation_stream`` helper as part of that. ``TestBatchGeneratorDispatch._make_reconcile_batch`` still tried to monkeypatch that name and failed at collection of every test that depends on the fixture with ``AttributeError: module ... has no attribute '_get_generation_stream'`` — taking down 4 reconcile-path tests on every CI run. The override is no longer needed: the surrounding fixture replaces ``_rebuild_singleton_cache`` and ``_call_backbone`` with fakes that do all of their work via ``np.array`` / ``mx.array`` directly, so neither MLX dispatch nor stream selection is reached. Tests (tests/test_mlx_lm_mtp_patch.py::TestBatchGeneratorDispatch): - test_reconcile_uses_queue_front_as_next_token - test_reconcile_empty_queue_samples_from_logits - test_reconcile_returns_false_on_empty_tokens - test_reconcile_fallback_on_rebuild_failure 11/11 TestBatchGeneratorDispatch tests pass.
…jundot#1485) Same root cause and fix as e93c408 for the VLM MTP drafter, applied to the SpecPrefill draft load path in both BatchedEngine and VLMBatchedEngine. `mlx_lm.load(specprefill_draft)` materializes only model.parameters() via mx.eval and leaves frozen buffers (RoPE freqs, masked_embedding tables, etc.) lazy, bound to the global mlx loader thread's stream. The draft model is then stored on `Scheduler._specprefill_draft_model` and later read from `score_tokens(self._specprefill_draft_model, ...)` at scheduler.py:4247, which runs on the per-engine executor's worker thread. The first SpecPrefill-eligible prefill raises ``no Stream(gpu, X) in current thread`` because `mx.async_eval` on the inference thread tries to materialize lazy ops against a stream that does not exist there. Call materialize_lazy_state(draft_model) inside _load_draft (which runs on get_mlx_executor()) right before returning, so every leaf array is concrete before any inference thread reads it. The VLM helper also refactors the dual-return into a single materialize+return to cover both the custom_quantization and standard mlx_lm.load paths. Closes the SpecPrefill instance of the jundot#1304 per-engine-threads bug class identified after 9d5bed8 (main VLM model), e93c408 (VLM MTP drafter), and 9407468 (boundary snapshots). Not directly testable without running SpecPrefill against a per-engine configuration, but matches the merged e93c408 pattern exactly.
Summary
When multiple LM engines run concurrently, they share a single
_global_mlx_executorwithmax_workers=1, serializing allscheduler.step()calls through one thread. More critically, the MTP patch reads the module-levelgeneration_streamviasys.modulesfor its forward passes, bypassing whatever stream theBatchGeneratorwas instantiated with. If two MTP-capable engines run simultaneously, their MTP forwards land on the same module-level stream regardless of which engine dispatched them — a stream-ordering violation that upstream'sBatchGenerator(stream=...)parameter (mlx-lm 0.31.3) was designed to prevent.This PR gives each
EngineCoreits ownThreadPoolExecutorandmx.Stream, passes the stream throughSchedulerintoBatchGenerator, and removes the_get_generation_stream()indirection from the MTP patch so MTP operations inherit the correct per-engine stream from the enclosingBatchGeneratorcontext — the same pattern upstream'sGenerationBatch._step()already uses.The global executor is retained for non-LM engines (TTS, STT, embedding, reranker) that still rely on
get_mlx_executor()and_init_mlx_thread.Changes
engine_core.py:EngineCore.__init__creates a per-engineThreadPoolExecutor+mx.new_thread_local_stream()and passes the stream toScheduler.close()shuts down the per-engine executor after scheduler cleanup. Added_ensure_wired_limit()so the process-globalmx.set_wired_limit()runs once rather than racing across concurrentBatchGeneratorinits.scheduler.py:Scheduler.__init__accepts an optionalstreamparameter (falls back to the module-levelgeneration_streamwhen not provided). All 37 internal references togeneration_stream— sync barriers, cache clears,mx.stream()context managers,BatchGeneratorcreation — now useself._stream.batch_generator.py(MTP patch): Removed_get_generation_stream()and the 4 explicitwith mx.stream(...)wrappers that pushed the module-level stream. MTP forwards now inherit the per-engine stream from the enclosingBatchGeneratorcontext, matchingGenerationBatch._step()'s existing pattern.Concurrent throughput
Two models generating simultaneously vs sequentially:
Sub-2x is expected — Metal command buffers still serialize on one GPU. The win is CPU-side overlap (prefill + decode can be submitted concurrently) and eliminating head-of-line blocking where one engine's long prefill stalls another's token emission.
Test plan
tests/test_per_engine_threads.py(10 tests): verifiesSchedulerstores and uses explicit streams, regex-scans theSchedulerclass body for baregeneration_streamreferences, confirms eachEngineCoregets a distinct executor/stream, validates executor shutdown onclose(), and asserts the MTP patch no longer contains_get_generation_streamor anygeneration_streamreference.tests/test_engine_core.py: existing executor tests now assertis not(distinct executors) and concurrent execution (both executors active simultaneously).Related to #1248